#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright (c) 2024 Song Yan 843558935@qq.com. All rights reserved.
# This code is licensed under the GPL-2.0+ license. See LICENSE file in the project root for full license information.

import logging
from invoke import task

from tasks.constants import DNF_CONF_PATH, DNF_UPGRADE_DATA, RELEASE_VERSION
from tasks.post_upgrade import post_upgrade
from tasks.pre_upgrade import fix_release_version, prepare_upgrade
from tasks.record_status import status_manager
from tasks.utils import safe_run_command

# Module-level logger used by all upgrade tasks in this file.
logger = logging.getLogger(__name__)


def exec_upgrade(ctx):
    """Run the prepared dnf shell script against the ``*-oc9`` repositories.

    Rebuilds the RPM database and clears dnf's dbcache/metadata. It then runs
    ``dnf shell`` on the script at ``DNF_UPGRADE_DATA``. All repositories are
    disabled except those matching the ``*-oc9`` glob. Erasing packages is
    allowed (``--allowerasing``) and best candidates are not required
    (``--nobest``).

    :param ctx: invoke context used to run shell commands.
    :return: True if the dnf shell command reported success, False otherwise.
    """
    safe_run_command(ctx, "rpm --rebuilddb")
    # This must be an f-string. Without the prefix, dnf was handed the literal
    # text "{DNF_CONF_PATH}" as its config file path.
    safe_run_command(ctx, f"dnf -c {DNF_CONF_PATH} clean -q dbcache metadata")
    command = (
        f"dnf -c {DNF_CONF_PATH} -y shell --allowerasing --nobest "
        f"--releasever={RELEASE_VERSION} --disablerepo='*' --enablerepo='*-oc9' "
        f"{DNF_UPGRADE_DATA}"
    )
    if not safe_run_command(ctx, command).ok:
        logger.error("execute upgrade failed.")
        return False
    logger.info("execute upgrade successfully.")
    return True



# Upgrade the groups that are currently installed.
@task
@status_manager.skip_if_success_decorator
def upgrade_installd_groups(ctx):
    """Upgrade every installed dnf group through a generated dnf shell script.

    Lists installed groups with ``dnf grouplist --installed``. Writes one
    ``groupupdate "<group name>"`` line per group to ``DNF_UPGRADE_DATA``.
    Then replays that script via ``exec_upgrade``.

    :param ctx: invoke context used to run shell commands.
    :return: True on success, False if listing or upgrading the groups failed.
    """
    logger.info("Upgrading the installed groups...")
    command = (
        f"dnf -q -c {DNF_CONF_PATH} -y --releasever={RELEASE_VERSION} "
        "--disablerepo='*' --enablerepo='*-oc9' grouplist --installed"
    )
    result = safe_run_command(ctx, command)
    if not result.ok:
        logger.error("Failed to get installed groups")
        return False

    group_updates = []
    for line in result.stdout.splitlines():
        name = line.strip()
        # Skip blank lines and section headers (e.g. "Installed Groups:").
        # Blank lines previously became stray '"' lines in the dnf script.
        if not name or line.startswith("Installed"):
            continue
        # Strip the indent rather than replacing an exact 3-space run.
        # This tolerates other indentation widths and never alters spaces
        # inside the group name. Quoting keeps multi-word names one argument.
        group_updates.append(f'groupupdate "{name}"')

    with open(DNF_UPGRADE_DATA, "w", encoding="utf-8") as f:
        f.write("\n".join(group_updates))

    if not exec_upgrade(ctx):
        logger.error("Failed to upgrade installed groups")
        return False
    return True

@task(pre=[prepare_upgrade, fix_release_version], post=[upgrade_installd_groups, post_upgrade])
@status_manager.skip_if_success_decorator
def upgrade_system(ctx):
    """Upgrade the system's packages to OpenCloudOS ``RELEASE_VERSION``.

    ``prepare_upgrade`` and ``fix_release_version`` are declared as
    pre-tasks. ``upgrade_installd_groups`` and ``post_upgrade`` are declared
    as post-tasks.

    :param ctx: invoke context used to run shell commands.
    :return: True if the dnf upgrade transaction succeeded, False otherwise.
    """
    logger.info("Starting the upgrade to opencloudos {}...".format(RELEASE_VERSION))

    # Work around package post-install scripts blocking the upgrade: make
    # firewall-cmd temporarily non-executable while the transaction runs.
    safe_run_command(ctx, "chmod -x /usr/bin/firewall-cmd")
    try:
        upgraded = exec_upgrade(ctx)
    finally:
        # Always restore the execute bit, even when the upgrade fails or
        # raises. Previously a failed upgrade left firewall-cmd unusable.
        safe_run_command(ctx, "chmod +x /usr/bin/firewall-cmd")

    if not upgraded:
        logger.error("Upgrade failed.")
        return False
    logger.info("Upgrade to opencloudos {} successfully.".format(RELEASE_VERSION))
    return True